package org.misty.spring.redis.filter;

import lombok.Getter;
import lombok.Setter;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Bloom filter whose bit array is stored in a single Redis bitmap.
 *
 * <p>Sizing is bound from the {@code bloom.filter} configuration prefix ({@code expectedInsertions},
 * {@code fpp}). The bitmap size and hash-function count are computed once in {@link #init()}, using
 * formulas copied from Guava 29.0. Changing the properties through the setters after {@code init()}
 * has run does not resize the filter.
 *
 * <p>Each instance writes to its own Redis key {@code "bloomFilter-N"}. {@code N} comes from a
 * JVM-wide counter that starts at 1.
 */
@ConfigurationProperties("bloom.filter")
@Component
public class RedisBloomFilter {
    /** Number of insertions the filter is sized for; must be >= 0 (0 is sized as 1). */
    @Getter
    @Setter
    private long expectedInsertions;

    /** Desired false-positive probability; must lie strictly between 0.0 and 1.0. */
    @Getter
    @Setter
    private double fpp;

    @Autowired
    private StringRedisTemplate stringRedisTemplate;

    /** Redis bitmap offsets must be smaller than 2^32, so this is the largest usable bitmap. */
    private static final long MAX_REDIS_BITS = 1L << 32;

    /** Supplies the per-instance numeric suffix of {@link #redisKey}. */
    private static final AtomicInteger id = new AtomicInteger(1);

    /**
     * Lua script that runs SETBIT with value 1 on the bitmap named by {@code KEYS[1]} for every
     * offset passed in {@code ARGV}. All bits for one key are therefore written in one script call.
     */
    public static final String luaScript = "local k_bloomFilter = KEYS[1]\n" +
        "local v_indexes = ARGV\n" +
        "for i, v in ipairs(v_indexes) do\n" +
        "  redis.call('setbit', k_bloomFilter, v, 1)\n" +
        "end\n" +
        "return 'OK'";
    private static final RedisScript<String> redisScript = RedisScript.of(luaScript);

    private final String redisKey = "bloomFilter-" + id.getAndIncrement();
    /** {@link #redisKey} encoded once, for the raw-byte GETBIT calls in {@link #isExists}. */
    private final byte[] rawRedisKey = redisKey.getBytes(StandardCharsets.UTF_8);
    private long numBits;
    private int numHashFunctions;

    /**
     * Validates the configured properties and computes the bitmap size and the hash-function count.
     *
     * @throws IllegalArgumentException if {@code expectedInsertions < 0}, if {@code fpp} is not in
     *     (0.0, 1.0), or if the resulting bitmap would exceed the 2^32-bit Redis limit
     */
    @PostConstruct
    public void init() {
        checkArgument(expectedInsertions >= 0, "Expected insertions expectedInsertions(%s) must be >= 0", expectedInsertions);
        checkArgument(fpp > 0.0, "False positive probability fpp(%s) must be > 0.0", fpp);
        checkArgument(fpp < 1.0, "False positive probability fpp(%s) must be < 1.0", fpp);
        // As in Guava's BloomFilter.create, size an "empty" filter as if for one insertion.
        // With n == 0 the bitmap would have 0 bits and getIndexes() would compute "% 0".
        long n = Math.max(1L, this.expectedInsertions);
        // A tiny n combined with an fpp close to 1 can also truncate to 0 bits; keep at least one.
        this.numBits = Math.max(1L, optimalNumOfBits(n, this.fpp));
        // Fail at startup rather than on every SETBIT/GETBIT with an out-of-range offset.
        checkArgument(numBits <= MAX_REDIS_BITS, "Bitmap size numBits(%s) exceeds the Redis limit of 2^32 bits", numBits);
        this.numHashFunctions = optimalNumOfHashFunctions(n, this.numBits);
    }

    /**
     * Adds {@code key} to the filter by delegating to {@link #putLua(String)}.
     *
     * @param key the element to add; {@code null} is hashed like any other value (hash 0)
     */
    public void put(String key) {
        putLua(key);
    }

    /**
     * Sets all bits for {@code key} in the Redis bitmap with one execution of {@link #luaScript}.
     *
     * @param key the element to add
     */
    public void putLua(String key) {
        long[] indexes = getIndexes(key);
        var args = Arrays.stream(indexes).mapToObj(String::valueOf).toArray();
        stringRedisTemplate.execute(redisScript, List.of(redisKey), args);
    }

    /**
     * Checks the bits for {@code key} with one pipelined GETBIT per hash function.
     *
     * @param key the element to look up
     * @return {@code false} if at least one of the key's bits is unset, meaning the key was not added
     *     (assuming the bitmap was not modified outside this class); {@code true} if all bits are set,
     *     which can also happen for keys that were never added (false positive)
     */
    public boolean isExists(String key) {
        long[] indexes = getIndexes(key);
        List<Object> results = stringRedisTemplate.executePipelined((RedisCallback<Void>) connection -> {
            for (long index : indexes) {
                connection.getBit(rawRedisKey, index);
            }
            // Pipelined callbacks must return null; the GETBIT replies are collected by the template.
            return null;
        });
        return !results.contains(Boolean.FALSE);
    }

    /**
     * Computes the {@code numHashFunctions} bit offsets for {@code key} using double hashing,
     * {@code hash1 + i * hash2}, derived from {@link Objects#hashCode(Object)}.
     *
     * <p>NOTE(review): when the key's hash is 0 (e.g. {@code null} or {@code ""}), {@code hash2} is
     * also 0 and every probe maps to bit 0. The hashing is intentionally left as-is: changing it would
     * move the bits of keys already stored in Redis and turn them into false negatives.
     */
    private long[] getIndexes(String key) {
        long[] result = new long[numHashFunctions];
        long hash1 = Objects.hashCode(key);
        // On a long, ">>> 16" never yields a negative value, so this only falls back when it is zero.
        long hash2 = hash1 >>> 16;
        if (hash2 == 0) {
            hash2 = hash1 << 8;
        }
        for (int i = 0; i < numHashFunctions; i++) {
            long combinedHash = hash1 + i * hash2;
            // Flip negative values (including overflowed ones) so the modulo result is non-negative.
            if (combinedHash < 0) {
                combinedHash = ~combinedHash;
            }
            result[i] = combinedHash % numBits;
        }
        return result;
    }

    /*
     * Copied from Guava 29.0: optimal bitmap size m = -n * ln(p) / (ln 2)^2.
     */
    private static long optimalNumOfBits(long n, double p) {
        if (p == 0) {
            p = Double.MIN_VALUE;
        }
        return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
    }

    /*
     * Copied from Guava 29.0: optimal hash-function count k = (m / n) * ln 2, at least 1.
     */
    private static int optimalNumOfHashFunctions(long n, long m) {
        // (m / n) * log(2), but avoid truncation due to division!
        return Math.max(1, (int) Math.round((double) m / n * Math.log(2)));
    }

    /**
     * Throws {@link IllegalArgumentException} with a formatted message when {@code b} is false.
     *
     * @param b the condition that must hold
     * @param errorMessageTemplate a {@link String#format} template with one {@code %s}; "%s" if null
     * @param p1 the value substituted into the template
     */
    private static void checkArgument(boolean b, String errorMessageTemplate, @Nullable Object p1) {
        if (!b) {
            throw new IllegalArgumentException(String.format(errorMessageTemplate != null ? errorMessageTemplate : "%s", p1));
        }
    }
}
